{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset\n",
    "\n",
    "Dataset is derived from Fannie Mae’s [Single-Family Loan Performance Data](http://www.fanniemae.com/portal/funding-the-market/data/loan-performance-data.html) with all rights reserved by Fannie Mae. Refer to these [instructions](https://github.com/NVIDIA/spark-rapids-examples/blob/branch-23.10/docs/get-started/xgboost-examples/dataset/mortgage.md) to download the dataset.\n",
    "\n",
    "# ETL + XGBoost train & transform\n",
    "\n",
    "This notebook is an end-to-end example of ETL + XGBoost Train & Transform by using [Spark-Rapids](https://github.com/NVIDIA/spark-rapids) and [XGBoost](https://github.com/dmlc/xgboost) with GPU accelerated.\n",
    "<br>The main steps:\n",
    "1. Run ETL to generate 2 datasets for train and test<br>\n",
    "   You can choose to save the datasets or not by setting \"is_save_dataset\" to True or False.<br>\n",
    "   It means you don't need to save the dataset to disk after ETL and directly feed the dataframe to XGBoost train or transform.\n",
    "2. Run XGBoost train with the train dataset\n",
    "3. Run XGBoost transform with the test dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import os\n",
    "from pyspark import broadcast\n",
    "from pyspark.conf import SparkConf\n",
    "from pyspark.sql import SparkSession\n",
    "from pyspark.sql.functions import *\n",
    "from pyspark.sql.types import *\n",
    "from pyspark.sql.window import Window\n",
    "# if you pass/unpack the archive file and enable the environment\n",
    "# os.environ['PYSPARK_PYTHON'] = \"./environment/bin/python\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Part\n",
    "### 1. Define the paths\n",
    "You need to update them to your real paths."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The input path of dataset\n",
    "dataRoot = os.getenv(\"DATA_ROOT\", \"/data\")\n",
    "orig_raw_path = dataRoot + \"/mortgage/input/\"\n",
    "orig_raw_path_csv2parquet = dataRoot + \"/mortgage/output/csv2parquet/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "SPARK_MASTER_URL = os.getenv(\"SPARK_MASTER_URL\", \"/your-url\")\n",
    "RAPIDS_JAR = os.getenv(\"RAPIDS_JAR\", \"/your-jar-path\")\n",
    "\n",
    "# You need to update with your real hardware resource \n",
    "driverMem = os.getenv(\"DRIVER_MEM\", \"10g\")\n",
    "executorMem = os.getenv(\"EXECUTOR_MEM\", \"10g\")\n",
    "pinnedPoolSize = os.getenv(\"PINNED_POOL_SIZE\", \"2g\")\n",
    "concurrentGpuTasks = os.getenv(\"CONCURRENT_GPU_TASKS\", \"2\")\n",
    "executorCores = int(os.getenv(\"EXECUTOR_CORES\", \"4\"))\n",
    "\n",
    "# Common spark settings\n",
    "conf = SparkConf()\n",
    "conf.setMaster(SPARK_MASTER_URL)\n",
    "conf.setAppName(\"Microbenchmark on GPU\")\n",
    "conf.set(\"spark.driver.memory\", driverMem)\n",
    "## The tasks will run on GPU memory, so there is no need to set a high host memory\n",
    "conf.set(\"spark.executor.memory\", executorMem)\n",
    "## The tasks will run on GPU cores, so there is no need to use many cpu cores\n",
    "conf.set(\"spark.executor.cores\", executorCores)\n",
    "\n",
    "# Plugin settings\n",
    "conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n",
    "conf.set(\"spark.rapids.sql.concurrentGpuTasks\", concurrentGpuTasks)\n",
    "conf.set(\"spark.rapids.memory.pinnedPool.size\", pinnedPoolSize)\n",
    "##############note: only support value=1 see https://github.com/dmlc/xgboost/blame/master/python-package/xgboost/spark/core.py#L370-L374\n",
    "conf.set(\"spark.task.resource.gpu.amount\", 1) \n",
    "# since pyspark and xgboost share the same GPU, we need to allocate some memory to xgboost to avoid GPU OOM while training \n",
    "conf.set(\"spark.rapids.memory.gpu.allocFraction\",\"0.6\")\n",
    "conf.set(\"spark.rapids.sql.enabled\", \"true\") \n",
    "conf.set(\"spark.plugins\", \"com.nvidia.spark.SQLPlugin\")\n",
    "conf.set(\"spark.sql.cache.serializer\",\"com.nvidia.spark.ParquetCachedBatchSerializer\")\n",
    "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", 200000) \n",
    "conf.set(\"spark.driver.extraClassPath\", RAPIDS_JAR)\n",
    "conf.set(\"spark.executor.extraClassPath\", RAPIDS_JAR)\n",
    "conf.set(\"spark.jars\", RAPIDS_JAR)\n",
    "\n",
    "# if you pass/unpack the archive file and enable the environment\n",
    "# conf.set(\"spark.yarn.dist.archives\", \"your_pyspark_venv.tar.gz#environment\")\n",
    "\n",
    "# Create spark session\n",
    "spark = SparkSession.builder.config(conf=conf).getOrCreate()\n",
    "reader = spark.read"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set True to save processed dataset after ETL\n",
    "# Set False, the dataset after ETL will be directly used in XGBoost train and transform\n",
    "\n",
    "is_save_dataset=True\n",
    "output_path_data=dataRoot + \"/mortgage/output/data/\"\n",
    "# the path to save the xgboost model\n",
    "output_path_model=dataRoot + \"/mortgage/output/model/\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Define the constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# File schema\n",
    "\n",
    "_csv_raw_schema = StructType([\n",
    "      StructField(\"reference_pool_id\", StringType()),\n",
    "      StructField(\"loan_id\", LongType()),\n",
    "      StructField(\"monthly_reporting_period\", StringType()),\n",
    "      StructField(\"orig_channel\", StringType()),\n",
    "      StructField(\"seller_name\", StringType()),\n",
    "      StructField(\"servicer\", StringType()),\n",
    "      StructField(\"master_servicer\", StringType()),\n",
    "      StructField(\"orig_interest_rate\", DoubleType()),\n",
    "      StructField(\"interest_rate\", DoubleType()),\n",
    "      StructField(\"orig_upb\", DoubleType()),\n",
    "      StructField(\"upb_at_issuance\", StringType()),\n",
    "      StructField(\"current_actual_upb\", DoubleType()),\n",
    "      StructField(\"orig_loan_term\", IntegerType()),\n",
    "      StructField(\"orig_date\", StringType()),\n",
    "      StructField(\"first_pay_date\", StringType()),    \n",
    "      StructField(\"loan_age\", DoubleType()),\n",
    "      StructField(\"remaining_months_to_legal_maturity\", DoubleType()),\n",
    "      StructField(\"adj_remaining_months_to_maturity\", DoubleType()),\n",
    "      StructField(\"maturity_date\", StringType()),\n",
    "      StructField(\"orig_ltv\", DoubleType()),\n",
    "      StructField(\"orig_cltv\", DoubleType()),\n",
    "      StructField(\"num_borrowers\", DoubleType()),\n",
    "      StructField(\"dti\", DoubleType()),\n",
    "      StructField(\"borrower_credit_score\", DoubleType()),\n",
    "      StructField(\"coborrow_credit_score\", DoubleType()),\n",
    "      StructField(\"first_home_buyer\", StringType()),\n",
    "      StructField(\"loan_purpose\", StringType()),\n",
    "      StructField(\"property_type\", StringType()),\n",
    "      StructField(\"num_units\", IntegerType()),\n",
    "      StructField(\"occupancy_status\", StringType()),\n",
    "      StructField(\"property_state\", StringType()),\n",
    "      StructField(\"msa\", DoubleType()),\n",
    "      StructField(\"zip\", IntegerType()),\n",
    "      StructField(\"mortgage_insurance_percent\", DoubleType()),\n",
    "      StructField(\"product_type\", StringType()),\n",
    "      StructField(\"prepayment_penalty_indicator\", StringType()),\n",
    "      StructField(\"interest_only_loan_indicator\", StringType()),\n",
    "      StructField(\"interest_only_first_principal_and_interest_payment_date\", StringType()),\n",
    "      StructField(\"months_to_amortization\", StringType()),\n",
    "      StructField(\"current_loan_delinquency_status\", IntegerType()),\n",
    "      StructField(\"loan_payment_history\", StringType()),\n",
    "      StructField(\"mod_flag\", StringType()),\n",
    "      StructField(\"mortgage_insurance_cancellation_indicator\", StringType()),\n",
    "      StructField(\"zero_balance_code\", StringType()),\n",
    "      StructField(\"zero_balance_effective_date\", StringType()),\n",
    "      StructField(\"upb_at_the_time_of_removal\", StringType()),\n",
    "      StructField(\"repurchase_date\", StringType()),\n",
    "      StructField(\"scheduled_principal_current\", StringType()),\n",
    "      StructField(\"total_principal_current\", StringType()),\n",
    "      StructField(\"unscheduled_principal_current\", StringType()),\n",
    "      StructField(\"last_paid_installment_date\", StringType()),\n",
    "      StructField(\"foreclosed_after\", StringType()),\n",
    "      StructField(\"disposition_date\", StringType()),\n",
    "      StructField(\"foreclosure_costs\", DoubleType()),\n",
    "      StructField(\"prop_preservation_and_repair_costs\", DoubleType()),\n",
    "      StructField(\"asset_recovery_costs\", DoubleType()),\n",
    "      StructField(\"misc_holding_expenses\", DoubleType()),\n",
    "      StructField(\"holding_taxes\", DoubleType()),\n",
    "      StructField(\"net_sale_proceeds\", DoubleType()),\n",
    "      StructField(\"credit_enhancement_proceeds\", DoubleType()),\n",
    "      StructField(\"repurchase_make_whole_proceeds\", StringType()),\n",
    "      StructField(\"other_foreclosure_proceeds\", DoubleType()),\n",
    "      StructField(\"non_interest_bearing_upb\", DoubleType()),\n",
    "      StructField(\"principal_forgiveness_upb\", StringType()),\n",
    "      StructField(\"original_list_start_date\", StringType()),\n",
    "      StructField(\"original_list_price\", StringType()),\n",
    "      StructField(\"current_list_start_date\", StringType()),\n",
    "      StructField(\"current_list_price\", StringType()),\n",
    "      StructField(\"borrower_credit_score_at_issuance\", StringType()),\n",
    "      StructField(\"co-borrower_credit_score_at_issuance\", StringType()),\n",
    "      StructField(\"borrower_credit_score_current\", StringType()),\n",
    "      StructField(\"co-Borrower_credit_score_current\", StringType()),\n",
    "      StructField(\"mortgage_insurance_type\", DoubleType()),\n",
    "      StructField(\"servicing_activity_indicator\", StringType()),\n",
    "      StructField(\"current_period_modification_loss_amount\", StringType()),\n",
    "      StructField(\"cumulative_modification_loss_amount\", StringType()),\n",
    "      StructField(\"current_period_credit_event_net_gain_or_loss\", StringType()),\n",
    "      StructField(\"cumulative_credit_event_net_gain_or_loss\", StringType()),\n",
    "      StructField(\"homeready_program_indicator\", StringType()),\n",
    "      StructField(\"foreclosure_principal_write_off_amount\", StringType()),\n",
    "      StructField(\"relocation_mortgage_indicator\", StringType()),\n",
    "      StructField(\"zero_balance_code_change_date\", StringType()),\n",
    "      StructField(\"loan_holdback_indicator\", StringType()),\n",
    "      StructField(\"loan_holdback_effective_date\", StringType()),\n",
    "      StructField(\"delinquent_accrued_interest\", StringType()),\n",
    "      StructField(\"property_valuation_method\", StringType()),\n",
    "      StructField(\"high_balance_loan_indicator\", StringType()),\n",
    "      StructField(\"arm_initial_fixed-rate_period_lt_5_yr_indicator\", StringType()),\n",
    "      StructField(\"arm_product_type\", StringType()),\n",
    "      StructField(\"initial_fixed-rate_period\", StringType()),\n",
    "      StructField(\"interest_rate_adjustment_frequency\", StringType()),\n",
    "      StructField(\"next_interest_rate_adjustment_date\", StringType()),\n",
    "      StructField(\"next_payment_change_date\", StringType()),\n",
    "      StructField(\"index\", StringType()),\n",
    "      StructField(\"arm_cap_structure\", StringType()),\n",
    "      StructField(\"initial_interest_rate_cap_up_percent\", StringType()),\n",
    "      StructField(\"periodic_interest_rate_cap_up_percent\", StringType()),\n",
    "      StructField(\"lifetime_interest_rate_cap_up_percent\", StringType()),\n",
    "      StructField(\"mortgage_margin\", StringType()),\n",
    "      StructField(\"arm_balloon_indicator\", StringType()),\n",
    "      StructField(\"arm_plan_number\", StringType()),\n",
    "      StructField(\"borrower_assistance_plan\", StringType()),\n",
    "      StructField(\"hltv_refinance_option_indicator\", StringType()),\n",
    "      StructField(\"deal_name\", StringType()),\n",
    "      StructField(\"repurchase_make_whole_proceeds_flag\", StringType()),\n",
    "      StructField(\"alternative_delinquency_resolution\", StringType()),\n",
    "      StructField(\"alternative_delinquency_resolution_count\", StringType()),\n",
    "      StructField(\"total_deferral_amount\", StringType())\n",
    "      ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# name mappings\n",
    "_name_mapping = [\n",
    "        (\"WITMER FUNDING, LLC\", \"Witmer\"),\n",
    "        (\"WELLS FARGO CREDIT RISK TRANSFER SECURITIES TRUST 2015\", \"Wells Fargo\"),\n",
    "        (\"WELLS FARGO BANK,  NA\" , \"Wells Fargo\"),\n",
    "        (\"WELLS FARGO BANK, N.A.\" , \"Wells Fargo\"),\n",
    "        (\"WELLS FARGO BANK, NA\" , \"Wells Fargo\"),\n",
    "        (\"USAA FEDERAL SAVINGS BANK\" , \"USAA\"),\n",
    "        (\"UNITED SHORE FINANCIAL SERVICES, LLC D\\\\/B\\\\/A UNITED WHOLESALE MORTGAGE\" , \"United Seq(e\"),\n",
    "        (\"U.S. BANK N.A.\" , \"US Bank\"),\n",
    "        (\"SUNTRUST MORTGAGE INC.\" , \"Suntrust\"),\n",
    "        (\"STONEGATE MORTGAGE CORPORATION\" , \"Stonegate Mortgage\"),\n",
    "        (\"STEARNS LENDING, LLC\" , \"Stearns Lending\"),\n",
    "        (\"STEARNS LENDING, INC.\" , \"Stearns Lending\"),\n",
    "        (\"SIERRA PACIFIC MORTGAGE COMPANY, INC.\" , \"Sierra Pacific Mortgage\"),\n",
    "        (\"REGIONS BANK\" , \"Regions\"),\n",
    "        (\"RBC MORTGAGE COMPANY\" , \"RBC\"),\n",
    "        (\"QUICKEN LOANS INC.\" , \"Quicken Loans\"),\n",
    "        (\"PULTE MORTGAGE, L.L.C.\" , \"Pulte Mortgage\"),\n",
    "        (\"PROVIDENT FUNDING ASSOCIATES, L.P.\" , \"Provident Funding\"),\n",
    "        (\"PROSPECT MORTGAGE, LLC\" , \"Prospect Mortgage\"),\n",
    "        (\"PRINCIPAL RESIDENTIAL MORTGAGE CAPITAL RESOURCES, LLC\" , \"Principal Residential\"),\n",
    "        (\"PNC BANK, N.A.\" , \"PNC\"),\n",
    "        (\"PMT CREDIT RISK TRANSFER TRUST 2015-2\" , \"PennyMac\"),\n",
    "        (\"PHH MORTGAGE CORPORATION\" , \"PHH Mortgage\"),\n",
    "        (\"PENNYMAC CORP.\" , \"PennyMac\"),\n",
    "        (\"PACIFIC UNION FINANCIAL, LLC\" , \"Other\"),\n",
    "        (\"OTHER\" , \"Other\"),\n",
    "        (\"NYCB MORTGAGE COMPANY, LLC\" , \"NYCB\"),\n",
    "        (\"NEW YORK COMMUNITY BANK\" , \"NYCB\"),\n",
    "        (\"NETBANK FUNDING SERVICES\" , \"Netbank\"),\n",
    "        (\"NATIONSTAR MORTGAGE, LLC\" , \"Nationstar Mortgage\"),\n",
    "        (\"METLIFE BANK, NA\" , \"Metlife\"),\n",
    "        (\"LOANDEPOT.COM, LLC\" , \"LoanDepot.com\"),\n",
    "        (\"J.P. MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2015-1\" , \"JP Morgan Chase\"),\n",
    "        (\"J.P. MORGAN MADISON AVENUE SECURITIES TRUST, SERIES 2014-1\" , \"JP Morgan Chase\"),\n",
    "        (\"JPMORGAN CHASE BANK, NATIONAL ASSOCIATION\" , \"JP Morgan Chase\"),\n",
    "        (\"JPMORGAN CHASE BANK, NA\" , \"JP Morgan Chase\"),\n",
    "        (\"JP MORGAN CHASE BANK, NA\" , \"JP Morgan Chase\"),\n",
    "        (\"IRWIN MORTGAGE, CORPORATION\" , \"Irwin Mortgage\"),\n",
    "        (\"IMPAC MORTGAGE CORP.\" , \"Impac Mortgage\"),\n",
    "        (\"HSBC BANK USA, NATIONAL ASSOCIATION\" , \"HSBC\"),\n",
    "        (\"HOMEWARD RESIDENTIAL, INC.\" , \"Homeward Mortgage\"),\n",
    "        (\"HOMESTREET BANK\" , \"Other\"),\n",
    "        (\"HOMEBRIDGE FINANCIAL SERVICES, INC.\" , \"HomeBridge\"),\n",
    "        (\"HARWOOD STREET FUNDING I, LLC\" , \"Harwood Mortgage\"),\n",
    "        (\"GUILD MORTGAGE COMPANY\" , \"Guild Mortgage\"),\n",
    "        (\"GMAC MORTGAGE, LLC (USAA FEDERAL SAVINGS BANK)\" , \"GMAC\"),\n",
    "        (\"GMAC MORTGAGE, LLC\" , \"GMAC\"),\n",
    "        (\"GMAC (USAA)\" , \"GMAC\"),\n",
    "        (\"FREMONT BANK\" , \"Fremont Bank\"),\n",
    "        (\"FREEDOM MORTGAGE CORP.\" , \"Freedom Mortgage\"),\n",
    "        (\"FRANKLIN AMERICAN MORTGAGE COMPANY\" , \"Franklin America\"),\n",
    "        (\"FLEET NATIONAL BANK\" , \"Fleet National\"),\n",
    "        (\"FLAGSTAR CAPITAL MARKETS CORPORATION\" , \"Flagstar Bank\"),\n",
    "        (\"FLAGSTAR BANK, FSB\" , \"Flagstar Bank\"),\n",
    "        (\"FIRST TENNESSEE BANK NATIONAL ASSOCIATION\" , \"Other\"),\n",
    "        (\"FIFTH THIRD BANK\" , \"Fifth Third Bank\"),\n",
    "        (\"FEDERAL HOME LOAN BANK OF CHICAGO\" , \"Fedral Home of Chicago\"),\n",
    "        (\"FDIC, RECEIVER, INDYMAC FEDERAL BANK FSB\" , \"FDIC\"),\n",
    "        (\"DOWNEY SAVINGS AND LOAN ASSOCIATION, F.A.\" , \"Downey Mortgage\"),\n",
    "        (\"DITECH FINANCIAL LLC\" , \"Ditech\"),\n",
    "        (\"CITIMORTGAGE, INC.\" , \"Citi\"),\n",
    "        (\"CHICAGO MORTGAGE SOLUTIONS DBA INTERFIRST MORTGAGE COMPANY\" , \"Chicago Mortgage\"),\n",
    "        (\"CHICAGO MORTGAGE SOLUTIONS DBA INTERBANK MORTGAGE COMPANY\" , \"Chicago Mortgage\"),\n",
    "        (\"CHASE HOME FINANCE, LLC\" , \"JP Morgan Chase\"),\n",
    "        (\"CHASE HOME FINANCE FRANKLIN AMERICAN MORTGAGE COMPANY\" , \"JP Morgan Chase\"),\n",
    "        (\"CHASE HOME FINANCE (CIE 1)\" , \"JP Morgan Chase\"),\n",
    "        (\"CHASE HOME FINANCE\" , \"JP Morgan Chase\"),\n",
    "        (\"CASHCALL, INC.\" , \"CashCall\"),\n",
    "        (\"CAPITAL ONE, NATIONAL ASSOCIATION\" , \"Capital One\"),\n",
    "        (\"CALIBER HOME LOANS, INC.\" , \"Caliber Funding\"),\n",
    "        (\"BISHOPS GATE RESIDENTIAL MORTGAGE TRUST\" , \"Bishops Gate Mortgage\"),\n",
    "        (\"BANK OF AMERICA, N.A.\" , \"Bank of America\"),\n",
    "        (\"AMTRUST BANK\" , \"AmTrust\"),\n",
    "        (\"AMERISAVE MORTGAGE CORPORATION\" , \"Amerisave\"),\n",
    "        (\"AMERIHOME MORTGAGE COMPANY, LLC\" , \"AmeriHome Mortgage\"),\n",
    "        (\"ALLY BANK\" , \"Ally Bank\"),\n",
    "        (\"ACADEMY MORTGAGE CORPORATION\" , \"Academy Mortgage\"),\n",
    "        (\"NO CASH-OUT REFINANCE\" , \"OTHER REFINANCE\"),\n",
    "        (\"REFINANCE - NOT SPECIFIED\" , \"OTHER REFINANCE\"),\n",
    "        (\"Other REFINANCE\" , \"OTHER REFINANCE\")]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# String columns\n",
    "cate_col_names = [\n",
    "        \"orig_channel\",\n",
    "        \"first_home_buyer\",\n",
    "        \"loan_purpose\",\n",
    "        \"property_type\",\n",
    "        \"occupancy_status\",\n",
    "        \"property_state\",\n",
    "        \"product_type\",\n",
    "        \"relocation_mortgage_indicator\",\n",
    "        \"seller_name\",\n",
    "        \"mod_flag\"\n",
    "]\n",
    "# Numeric columns\n",
    "label_col_name = \"delinquency_12\"\n",
    "numeric_col_names = [\n",
    "        \"orig_interest_rate\",\n",
    "        \"orig_upb\",\n",
    "        \"orig_loan_term\",\n",
    "        \"orig_ltv\",\n",
    "        \"orig_cltv\",\n",
    "        \"num_borrowers\",\n",
    "        \"dti\",\n",
    "        \"borrower_credit_score\",\n",
    "        \"num_units\",\n",
    "        \"zip\",\n",
    "        \"mortgage_insurance_percent\",\n",
    "        \"current_loan_delinquency_status\",\n",
    "        \"current_actual_upb\",\n",
    "        \"interest_rate\",\n",
    "        \"loan_age\",\n",
    "        \"msa\",\n",
    "        \"non_interest_bearing_upb\",\n",
    "        label_col_name\n",
    "]\n",
    "all_col_names = cate_col_names + numeric_col_names"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Define ETL Process\n",
    "\n",
    "Define the function to do the ETL process\n",
    "\n",
    "#### 3.1 Define Functions to Read Raw CSV File\n",
    "\n",
    "* Define function to get quarter from input CSV file name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _get_quarter_from_csv_file_name():\n",
    "    return substring_index(substring_index(input_file_name(), \".\", 1), \"/\", -1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Define function to read raw CSV data file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_raw_csv(spark, path):\n",
    "    return spark.read.format('csv') \\\n",
    "            .option('nullValue', '') \\\n",
    "            .option('header', False) \\\n",
    "            .option('delimiter', '|') \\\n",
    "            .schema(_csv_raw_schema) \\\n",
    "            .load(path) \\\n",
    "            .withColumn('quarter', _get_quarter_from_csv_file_name())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Functions to extract perf and acq columns from raw schema"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_perf_columns(rawDf):\n",
    "    perfDf = rawDf.select(\n",
    "      col(\"loan_id\"),\n",
    "      date_format(to_date(col(\"monthly_reporting_period\"),\"MMyyyy\"), \"MM/dd/yyyy\").alias(\"monthly_reporting_period\"),\n",
    "      upper(col(\"servicer\")).alias(\"servicer\"),\n",
    "      col(\"interest_rate\"),\n",
    "      col(\"current_actual_upb\"),\n",
    "      col(\"loan_age\"),\n",
    "      col(\"remaining_months_to_legal_maturity\"),\n",
    "      col(\"adj_remaining_months_to_maturity\"),\n",
    "      date_format(to_date(col(\"maturity_date\"),\"MMyyyy\"), \"MM/yyyy\").alias(\"maturity_date\"),\n",
    "      col(\"msa\"),\n",
    "      col(\"current_loan_delinquency_status\"),\n",
    "      col(\"mod_flag\"),\n",
    "      col(\"zero_balance_code\"),\n",
    "      date_format(to_date(col(\"zero_balance_effective_date\"),\"MMyyyy\"), \"MM/yyyy\").alias(\"zero_balance_effective_date\"),\n",
    "      date_format(to_date(col(\"last_paid_installment_date\"),\"MMyyyy\"), \"MM/dd/yyyy\").alias(\"last_paid_installment_date\"),\n",
    "      date_format(to_date(col(\"foreclosed_after\"),\"MMyyyy\"), \"MM/dd/yyyy\").alias(\"foreclosed_after\"),\n",
    "      date_format(to_date(col(\"disposition_date\"),\"MMyyyy\"), \"MM/dd/yyyy\").alias(\"disposition_date\"),\n",
    "      col(\"foreclosure_costs\"),\n",
    "      col(\"prop_preservation_and_repair_costs\"),\n",
    "      col(\"asset_recovery_costs\"),\n",
    "      col(\"misc_holding_expenses\"),\n",
    "      col(\"holding_taxes\"),\n",
    "      col(\"net_sale_proceeds\"),\n",
    "      col(\"credit_enhancement_proceeds\"),\n",
    "      col(\"repurchase_make_whole_proceeds\"),\n",
    "      col(\"other_foreclosure_proceeds\"),\n",
    "      col(\"non_interest_bearing_upb\"),\n",
    "      col(\"principal_forgiveness_upb\"),\n",
    "      col(\"repurchase_make_whole_proceeds_flag\"),\n",
    "      col(\"foreclosure_principal_write_off_amount\"),\n",
    "      col(\"servicing_activity_indicator\"),\n",
    "      col('quarter')\n",
    "    )\n",
    "    return perfDf.select(\"*\").filter(\"current_actual_upb != 0.0\")\n",
    "\n",
    "def extract_acq_columns(rawDf):\n",
    "    acqDf = rawDf.select(\n",
    "      col(\"loan_id\"),\n",
    "      col(\"orig_channel\"),\n",
    "      upper(col(\"seller_name\")).alias(\"seller_name\"),\n",
    "      col(\"orig_interest_rate\"),\n",
    "      col(\"orig_upb\"),\n",
    "      col(\"orig_loan_term\"),\n",
    "      date_format(to_date(col(\"orig_date\"),\"MMyyyy\"), \"MM/yyyy\").alias(\"orig_date\"),\n",
    "      date_format(to_date(col(\"first_pay_date\"),\"MMyyyy\"), \"MM/yyyy\").alias(\"first_pay_date\"),\n",
    "      col(\"orig_ltv\"),\n",
    "      col(\"orig_cltv\"),\n",
    "      col(\"num_borrowers\"),\n",
    "      col(\"dti\"),\n",
    "      col(\"borrower_credit_score\"),\n",
    "      col(\"first_home_buyer\"),\n",
    "      col(\"loan_purpose\"),\n",
    "      col(\"property_type\"),\n",
    "      col(\"num_units\"),\n",
    "      col(\"occupancy_status\"),\n",
    "      col(\"property_state\"),\n",
    "      col(\"zip\"),\n",
    "      col(\"mortgage_insurance_percent\"),\n",
    "      col(\"product_type\"),\n",
    "      col(\"coborrow_credit_score\"),\n",
    "      col(\"mortgage_insurance_type\"),\n",
    "      col(\"relocation_mortgage_indicator\"),\n",
    "      dense_rank().over(Window.partitionBy(\"loan_id\").orderBy(to_date(col(\"monthly_reporting_period\"),\"MMyyyy\"))).alias(\"rank\"),\n",
    "      col('quarter')\n",
    "      )\n",
    "\n",
    "    return acqDf.select(\"*\").filter(col(\"rank\")==1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.2 Define ETL Process\n",
    "\n",
    "* Define function to parse dates in Performance data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _parse_dates(perf):\n",
    "    return perf \\\n",
    "            .withColumn(\"monthly_reporting_period\", to_date(col(\"monthly_reporting_period\"), \"MM/dd/yyyy\")) \\\n",
    "            .withColumn(\"monthly_reporting_period_month\", month(col(\"monthly_reporting_period\"))) \\\n",
    "            .withColumn(\"monthly_reporting_period_year\", year(col(\"monthly_reporting_period\"))) \\\n",
    "            .withColumn(\"monthly_reporting_period_day\", dayofmonth(col(\"monthly_reporting_period\"))) \\\n",
    "            .withColumn(\"last_paid_installment_date\", to_date(col(\"last_paid_installment_date\"), \"MM/dd/yyyy\")) \\\n",
    "            .withColumn(\"foreclosed_after\", to_date(col(\"foreclosed_after\"), \"MM/dd/yyyy\")) \\\n",
    "            .withColumn(\"disposition_date\", to_date(col(\"disposition_date\"), \"MM/dd/yyyy\")) \\\n",
    "            .withColumn(\"maturity_date\", to_date(col(\"maturity_date\"), \"MM/yyyy\")) \\\n",
    "            .withColumn(\"zero_balance_effective_date\", to_date(col(\"zero_balance_effective_date\"), \"MM/yyyy\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Define function to create deliquency data frame from Performance data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _create_perf_deliquency(spark, perf):\n",
    "    aggDF = perf.select(\n",
    "            col(\"quarter\"),\n",
    "            col(\"loan_id\"),\n",
    "            col(\"current_loan_delinquency_status\"),\n",
    "            when(col(\"current_loan_delinquency_status\") >= 1, col(\"monthly_reporting_period\")).alias(\"delinquency_30\"),\n",
    "            when(col(\"current_loan_delinquency_status\") >= 3, col(\"monthly_reporting_period\")).alias(\"delinquency_90\"),\n",
    "            when(col(\"current_loan_delinquency_status\") >= 6, col(\"monthly_reporting_period\")).alias(\"delinquency_180\")) \\\n",
    "            .groupBy(\"quarter\", \"loan_id\") \\\n",
    "            .agg(\n",
    "                max(\"current_loan_delinquency_status\").alias(\"delinquency_12\"),\n",
    "                min(\"delinquency_30\").alias(\"delinquency_30\"),\n",
    "                min(\"delinquency_90\").alias(\"delinquency_90\"),\n",
    "                min(\"delinquency_180\").alias(\"delinquency_180\")) \\\n",
    "            .select(\n",
    "                col(\"quarter\"),\n",
    "                col(\"loan_id\"),\n",
    "                (col(\"delinquency_12\") >= 1).alias(\"ever_30\"),\n",
    "                (col(\"delinquency_12\") >= 3).alias(\"ever_90\"),\n",
    "                (col(\"delinquency_12\") >= 6).alias(\"ever_180\"),\n",
    "                col(\"delinquency_30\"),\n",
    "                col(\"delinquency_90\"),\n",
    "                col(\"delinquency_180\"))\n",
    "    joinedDf = perf \\\n",
    "            .withColumnRenamed(\"monthly_reporting_period\", \"timestamp\") \\\n",
    "            .withColumnRenamed(\"monthly_reporting_period_month\", \"timestamp_month\") \\\n",
    "            .withColumnRenamed(\"monthly_reporting_period_year\", \"timestamp_year\") \\\n",
    "            .withColumnRenamed(\"current_loan_delinquency_status\", \"delinquency_12\") \\\n",
    "            .withColumnRenamed(\"current_actual_upb\", \"upb_12\") \\\n",
    "            .select(\"quarter\", \"loan_id\", \"timestamp\", \"delinquency_12\", \"upb_12\", \"timestamp_month\", \"timestamp_year\") \\\n",
    "            .join(aggDF, [\"loan_id\", \"quarter\"], \"left_outer\")\n",
    "\n",
    "    # calculate the 12 month delinquency and upb values\n",
    "    months = 12\n",
    "    monthArray = [lit(x) for x in range(0, 12)]\n",
    "    # explode on a small amount of data is actually slightly more efficient than a cross join\n",
    "    testDf = joinedDf \\\n",
    "            .withColumn(\"month_y\", explode(array(monthArray))) \\\n",
    "            .select(\n",
    "                    col(\"quarter\"),\n",
    "                    floor(((col(\"timestamp_year\") * 12 + col(\"timestamp_month\")) - 24000) / months).alias(\"josh_mody\"),\n",
    "                    floor(((col(\"timestamp_year\") * 12 + col(\"timestamp_month\")) - 24000 - col(\"month_y\")) / months).alias(\"josh_mody_n\"),\n",
    "                    col(\"ever_30\"),\n",
    "                    col(\"ever_90\"),\n",
    "                    col(\"ever_180\"),\n",
    "                    col(\"delinquency_30\"),\n",
    "                    col(\"delinquency_90\"),\n",
    "                    col(\"delinquency_180\"),\n",
    "                    col(\"loan_id\"),\n",
    "                    col(\"month_y\"),\n",
    "                    col(\"delinquency_12\"),\n",
    "                    col(\"upb_12\")) \\\n",
    "            .groupBy(\"quarter\", \"loan_id\", \"josh_mody_n\", \"ever_30\", \"ever_90\", \"ever_180\", \"delinquency_30\", \"delinquency_90\", \"delinquency_180\", \"month_y\") \\\n",
    "            .agg(max(\"delinquency_12\").alias(\"delinquency_12\"), min(\"upb_12\").alias(\"upb_12\")) \\\n",
    "            .withColumn(\"timestamp_year\", floor((lit(24000) + (col(\"josh_mody_n\") * lit(months)) + (col(\"month_y\") - 1)) / lit(12))) \\\n",
    "            .selectExpr(\"*\", \"pmod(24000 + (josh_mody_n * {}) + month_y, 12) as timestamp_month_tmp\".format(months)) \\\n",
    "            .withColumn(\"timestamp_month\", when(col(\"timestamp_month_tmp\") == lit(0), lit(12)).otherwise(col(\"timestamp_month_tmp\"))) \\\n",
    "            .withColumn(\"delinquency_12\", ((col(\"delinquency_12\") > 3).cast(\"int\") + (col(\"upb_12\") == 0).cast(\"int\")).alias(\"delinquency_12\")) \\\n",
    "            .drop(\"timestamp_month_tmp\", \"josh_mody_n\", \"month_y\")\n",
    "\n",
    "    return perf.withColumnRenamed(\"monthly_reporting_period_month\", \"timestamp_month\") \\\n",
    "            .withColumnRenamed(\"monthly_reporting_period_year\", \"timestamp_year\") \\\n",
    "            .join(testDf, [\"quarter\", \"loan_id\", \"timestamp_year\", \"timestamp_month\"], \"left\") \\\n",
    "            .drop(\"timestamp_year\", \"timestamp_month\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Define function to create acquisition data frame from Acquisition data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _create_acquisition(spark, acq):\n",
    "    nameMapping = spark.createDataFrame(_name_mapping, [\"from_seller_name\", \"to_seller_name\"])\n",
    "    return acq.join(nameMapping, col(\"seller_name\") == col(\"from_seller_name\"), \"left\") \\\n",
    "      .drop(\"from_seller_name\") \\\n",
    "      .withColumn(\"old_name\", col(\"seller_name\")) \\\n",
    "      .withColumn(\"seller_name\", coalesce(col(\"to_seller_name\"), col(\"seller_name\"))) \\\n",
    "      .drop(\"to_seller_name\") \\\n",
    "      .withColumn(\"orig_date\", to_date(col(\"orig_date\"), \"MM/yyyy\")) \\\n",
    "      .withColumn(\"first_pay_date\", to_date(col(\"first_pay_date\"), \"MM/yyyy\")) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.3 Define Casting Process\n",
    "This part is casting String column to Numeric one. \n",
    "Example:\n",
    "```\n",
    "col_1\n",
    " \"a\"\n",
    " \"b\"\n",
    " \"c\"\n",
    " \"a\"\n",
    "# After String ====> Numeric\n",
    "col_1\n",
    " 0\n",
    " 1\n",
    " 2\n",
    " 0\n",
    "```  \n",
    "<br>\n",
    "\n",
    "* Define function to get column dictionary\n",
    "\n",
    "    Example\n",
    "    ```\n",
    "    col1 = [row(data=\"a\",id=0), row(data=\"b\",id=1)]\n",
    "    ```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _gen_dictionary(etl_df, col_names):\n",
    "    cnt_table = etl_df.select(posexplode(array([col(i) for i in col_names])))\\\n",
    "                    .withColumnRenamed(\"pos\", \"column_id\")\\\n",
    "                    .withColumnRenamed(\"col\", \"data\")\\\n",
    "                    .filter(\"data is not null\")\\\n",
    "                    .groupBy(\"column_id\", \"data\")\\\n",
    "                    .count()\n",
    "    windowed = Window.partitionBy(\"column_id\").orderBy(desc(\"count\"))\n",
    "    return cnt_table.withColumn(\"id\", row_number().over(windowed)).drop(\"count\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Define function to convert string columns to numeric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _cast_string_columns_to_numeric(spark, input_df):\n",
    "    cached_dict_df = _gen_dictionary(input_df, cate_col_names).cache()\n",
    "    output_df = input_df\n",
    "    #  Generate the final table with all columns being numeric.\n",
    "    for col_pos, col_name in enumerate(cate_col_names):\n",
    "        col_dict_df = cached_dict_df.filter(col(\"column_id\") == col_pos)\\\n",
    "                                    .drop(\"column_id\")\\\n",
    "                                    .withColumnRenamed(\"data\", col_name)\n",
    "        \n",
    "        output_df = output_df.join(broadcast(col_dict_df), col_name, \"left\")\\\n",
    "                        .drop(col_name)\\\n",
    "                        .withColumnRenamed(\"id\", col_name)\n",
    "    return output_df        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.4 Define Main Function\n",
    "In this function:\n",
    "1. Parse date in Performance data by calling _parse_dates (parsed_perf)\n",
    "2. Create deliqency dataframe(perf_deliqency) form Performance data by calling _create_perf_deliquency\n",
    "3. Create cleaned acquisition dataframe(cleaned_acq) from Acquisition data by calling _create_acquisition\n",
    "4. Join deliqency dataframe(perf_deliqency) and cleaned acquisition dataframe(cleaned_acq), get clean_df\n",
    "5. Cast String column to Numeric in clean_df by calling _cast_string_columns_to_numeric, get casted_clean_df\n",
    "6. Return casted_clean_df as final result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_mortgage(spark, perf, acq):\n",
    "    parsed_perf = _parse_dates(perf)\n",
    "    perf_deliqency = _create_perf_deliquency(spark, parsed_perf)\n",
    "    cleaned_acq = _create_acquisition(spark, acq)\n",
    "    clean_df = perf_deliqency.join(cleaned_acq, [\"loan_id\", \"quarter\"], \"inner\").drop(\"quarter\")\n",
    "    casted_clean_df = _cast_string_columns_to_numeric(spark, clean_df)\\\n",
    "                    .select(all_col_names)\\\n",
    "                    .withColumn(label_col_name, when(col(label_col_name) > 0, 1).otherwise(0))\\\n",
    "                    .fillna(float(0))\n",
    "    return casted_clean_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Part\n",
    "### Run ETL\n",
    "#### 1. Add additional Spark settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GPU run, set to true\n",
    "spark.conf.set(\"spark.rapids.sql.enabled\", \"true\")\n",
    "# CPU run, set to false, it can only make ETL run on CPU when is_save_dataset=True.\n",
    "# spark.conf.set(\"spark.rapids.sql.enabled\", \"false\")\n",
    "spark.conf.set(\"spark.sql.files.maxPartitionBytes\", \"1G\")\n",
    "spark.conf.set(\"spark.rapids.sql.explain\", \"ALL\")\n",
    "spark.conf.set(\"spark.rapids.sql.batchSizeBytes\", \"512M\")\n",
    "spark.conf.set(\"spark.rapids.sql.reader.batchSizeBytes\", \"768M\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. Read Raw Data and Run ETL Process, Save the Result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ETL takes 135.9117729663849\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# read raw dataset\n",
    "rawDf = read_raw_csv(spark, orig_raw_path)\n",
    "rawDf.write.parquet(orig_raw_path_csv2parquet, mode='overwrite')\n",
    "rawDf = spark.read.parquet(orig_raw_path_csv2parquet)\n",
    "\n",
    "acq = extract_acq_columns(rawDf)\n",
    "perf = extract_perf_columns(rawDf)\n",
    "\n",
    "# run main function to process data\n",
    "out = run_mortgage(spark, perf, acq)\n",
    "\n",
    "# save processed data\n",
    "if is_save_dataset:\n",
    "    start = time.time()\n",
    "    out.write.parquet(output_path_data, mode=\"overwrite\")\n",
    "    end = time.time()\n",
    "    print(\"ETL takes {}\".format(end - start))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## XGBoost Spark with GPU"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###### Import ML Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from xgboost.spark import SparkXGBClassifier, SparkXGBClassifierModel\n",
    "from pyspark.ml.evaluation import MulticlassClassificationEvaluator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###### Create Data Reader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure it runs on GPU\n",
    "spark.conf.set(\"spark.rapids.sql.enabled\", \"true\")\n",
    "\n",
    "reader = spark.read"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###### Specify the Data Schema and Load the Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "label = \"delinquency_12\"\n",
    "schema = StructType([\n",
    "    StructField(\"orig_channel\", FloatType()),\n",
    "    StructField(\"first_home_buyer\", FloatType()),\n",
    "    StructField(\"loan_purpose\", FloatType()),\n",
    "    StructField(\"property_type\", FloatType()),\n",
    "    StructField(\"occupancy_status\", FloatType()),\n",
    "    StructField(\"property_state\", FloatType()),\n",
    "    StructField(\"product_type\", FloatType()),\n",
    "    StructField(\"relocation_mortgage_indicator\", FloatType()),\n",
    "    StructField(\"seller_name\", FloatType()),\n",
    "    StructField(\"mod_flag\", FloatType()),\n",
    "    StructField(\"orig_interest_rate\", FloatType()),\n",
    "    StructField(\"orig_upb\", DoubleType()),\n",
    "    StructField(\"orig_loan_term\", IntegerType()),\n",
    "    StructField(\"orig_ltv\", FloatType()),\n",
    "    StructField(\"orig_cltv\", FloatType()),\n",
    "    StructField(\"num_borrowers\", FloatType()),\n",
    "    StructField(\"dti\", FloatType()),\n",
    "    StructField(\"borrower_credit_score\", FloatType()),\n",
    "    StructField(\"num_units\", IntegerType()),\n",
    "    StructField(\"zip\", IntegerType()),\n",
    "    StructField(\"mortgage_insurance_percent\", FloatType()),\n",
    "    StructField(\"current_loan_delinquency_status\", IntegerType()),\n",
    "    StructField(\"current_actual_upb\", FloatType()),\n",
    "    StructField(\"interest_rate\", FloatType()),\n",
    "    StructField(\"loan_age\", FloatType()),\n",
    "    StructField(\"msa\", FloatType()),\n",
    "    StructField(\"non_interest_bearing_upb\", FloatType()),\n",
    "    StructField(label, IntegerType()),\n",
    "])\n",
    "features = [ x.name for x in schema if x.name != label ]\n",
    "\n",
    "if is_save_dataset:\n",
    "    # load dataset from file\n",
    "    etlDf = reader.parquet(output_path_data)\n",
    "    splits = etlDf.randomSplit([0.8, 0.2])\n",
    "    train_data = splits[0]\n",
    "    test_data = splits[1]\n",
    "else:\n",
    "    # use Dataframe from ETL directly\n",
    "    splits = out.randomSplit([0.8, 0.2])\n",
    "    train_data = splits[0]\n",
    "    test_data = splits[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This sample uses 1 worker(GPU) to run XGBoost training, you can change according to your GPU resources\n",
    "params = { \n",
    "    \"tree_method\": \"gpu_hist\",\n",
    "    \"grow_policy\": \"depthwise\",\n",
    "    \"num_workers\": 1,\n",
    "    \"device\": \"cuda\",\n",
    "}\n",
    "params['features_col'] = features\n",
    "params['label_col'] = label\n",
    "    \n",
    "classifier = SparkXGBClassifier(**params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training takes 18.92583155632019 seconds\n"
     ]
    }
   ],
   "source": [
    "def with_benchmark(phrase, action):\n",
    "    start = time.time()\n",
    "    result = action()\n",
    "    end = time.time()\n",
    "    print(\"{} takes {} seconds\".format(phrase, end - start))\n",
    "    return result\n",
    "model = with_benchmark(\"Training\", lambda: classifier.fit(train_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.write().overwrite().save(output_path_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = SparkXGBClassifierModel().load(output_path_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Transformation takes 8.959877967834473 seconds\n",
      "+--------------+--------------------+--------------------+----------+\n",
      "|delinquency_12|       rawPrediction|         probability|prediction|\n",
      "+--------------+--------------------+--------------------+----------+\n",
      "|             0|[7.92072248458862...|[0.99963699193904...|       0.0|\n",
      "|             0|[7.92072248458862...|[0.99963699193904...|       0.0|\n",
      "|             0|[8.43130302429199...|[0.99978211015695...|       0.0|\n",
      "|             0|[8.20779895782470...|[0.99972755435737...|       0.0|\n",
      "|             0|[8.885986328125,-...|[0.99986170543706...|       0.0|\n",
      "+--------------+--------------------+--------------------+----------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def transform():\n",
    "    result = loaded_model.transform(test_data).cache()\n",
    "    result.foreachPartition(lambda _: None)\n",
    "    return result\n",
    "result = with_benchmark(\"Transformation\", transform)\n",
    "result.select(label, \"rawPrediction\", \"probability\", \"prediction\").show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation takes 0.6158628463745117 seconds\n",
      "Accuracy is 0.9861453808970397\n"
     ]
    }
   ],
   "source": [
    "accuracy = with_benchmark(\n",
    "    \"Evaluation\",\n",
    "    lambda: MulticlassClassificationEvaluator().setLabelCol(label).evaluate(result))\n",
    "print(\"Accuracy is \" + str(accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "spark.stop()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  },
  "name": "gpu-mortgage",
  "notebookId": 4440374682851873
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
